package com.bosssoft.platform.fasttcc.support.resttemplate;

import com.bosssoft.platform.fasttcc.TccTransactionManager;
import com.bosssoft.platform.fasttcc.Xid;
import com.bosssoft.platform.fasttcc.rpc.command.TccMethodCallCommand;
import com.bosssoft.platform.fasttcc.rpc.filter.RpcFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.cloud.client.loadbalancer.LoadBalanced;
import org.springframework.cloud.client.loadbalancer.LoadBalancerInterceptor;
import org.springframework.cloud.client.loadbalancer.RestTemplateCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequestExecution;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.support.HttpRequestWrapper;
import org.springframework.web.client.RestTemplate;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

/**
 * Spring configuration that wires FastTCC transaction-context propagation into
 * {@link LoadBalanced @LoadBalanced} {@link RestTemplate}s.
 * <p>
 * The {@link RestTemplateCustomizer} declared here does two things to each customized template. It
 * appends the {@link LoadBalancerInterceptor}, then it appends {@link TccRestTempleteInteceptor}
 * after it. The TCC interceptor then adds FastTCC headers to outgoing plain-HTTP calls.
 */
@Configuration
public class ConfigRestTempleteInterceptor
{
    /**
     * Creates the application's load-balanced {@link RestTemplate}.
     *
     * @return a new {@link RestTemplate}, qualified with {@link LoadBalanced}
     */
    @Bean
    @LoadBalanced
    public RestTemplate restTemplate()
    {
        return new RestTemplate();
    }

    /** Value of {@code spring.application.name}; handed to the TCC interceptor on construction. */
    @Value("${spring.application.name}")
    private String applicationName;


    /**
     * Creates the interceptor that attaches FastTCC transaction headers to outgoing requests.
     *
     * @param rpcFilter filter consulted for the transaction context of every outgoing HTTP call
     * @return the TCC request interceptor
     */
    @Bean
    public TccRestTempleteInteceptor tccRestTempleteInteceptor(RpcFilter rpcFilter)
    {
        return new TccRestTempleteInteceptor(rpcFilter, applicationName);
    }

    /**
     * Creates a customizer that keeps a template's existing interceptors. It appends the load-balancer
     * interceptor first, then the TCC interceptor.
     *
     * @param loadBalancerInterceptor   interceptor appended first
     * @param tccRestTempleteInteceptor interceptor appended second, after load balancing
     * @param tccTransactionManager     not read by the customizer. NOTE(review): presumably declared so
     *                                  this bean is only created once a TccTransactionManager bean
     *                                  exists; TODO confirm
     * @return the customizer
     */
    @Bean
    public RestTemplateCustomizer restTemplateCustomizer(final LoadBalancerInterceptor loadBalancerInterceptor, final TccRestTempleteInteceptor tccRestTempleteInteceptor, TccTransactionManager tccTransactionManager)
    {
        return new RestTemplateCustomizer()
        {
            @Override
            public void customize(RestTemplate restTemplate)
            {
                // Copy the current list so interceptors already registered on the template are kept.
                List<ClientHttpRequestInterceptor> interceptors = new ArrayList<ClientHttpRequestInterceptor>(restTemplate.getInterceptors());
                // Order matters: the TCC interceptor runs after the load balancer, so the request it sees
                // is (presumably) already rewritten to a concrete instance and wrapped in an HttpRequestWrapper.
                interceptors.add(loadBalancerInterceptor);
                interceptors.add(tccRestTempleteInteceptor);
                restTemplate.setInterceptors(interceptors);
            }
        };
    }

    /**
     * Propagates the current FastTCC transaction context as HTTP headers on outgoing {@code http}
     * requests.
     * <p>
     * For each such request it calls {@link RpcFilter#beforeSendRequest} with a
     * {@code host:port:instanceName} target identifier. If a {@link TccMethodCallCommand} is
     * returned, the interceptor writes three headers: the command's call id to {@link #CALLID}, its
     * node identifier to {@link #NODEIDENTIFIER}, and its xid to {@link #XID}. Requests with any
     * other scheme, including a missing scheme, pass through untouched.
     */
    static class TccRestTempleteInteceptor implements ClientHttpRequestInterceptor
    {
        static final         String    CALLID           = "FASTTCC-CALLID";
        static final         String    NODEIDENTIFIER   = "FASTTCC-NODEIDENTIFIER";
        static final         String    XID              = "FASTTCC-XID";
        /** Instance name reported when the request is not an {@link HttpRequestWrapper}. */
        private static final String    UNKNOWN_INSTANCE = "unknow";
        private static final Logger    LOGGER           = LoggerFactory.getLogger(TccRestTempleteInteceptor.class);
        private final        RpcFilter rpcFilter;
        /** Name of the calling application; stored but not currently read by {@link #intercept}. */
        private final        String    applicationName;

        public TccRestTempleteInteceptor(RpcFilter rpcFilter, String applicationName)
        {
            this.rpcFilter = rpcFilter;
            this.applicationName = applicationName;
        }

        /**
         * Adds the FastTCC headers to {@code http} requests when the filter supplies a call command,
         * then continues the interceptor chain.
         *
         * @param httpRequest                the outgoing request; its headers may be modified
         * @param bytes                      the request body, forwarded unchanged
         * @param clientHttpRequestExecution the remaining interceptor chain
         * @return the response from the rest of the chain
         * @throws IOException if executing the rest of the chain fails
         */
        @Override
        public ClientHttpResponse intercept(HttpRequest httpRequest, byte[] bytes, ClientHttpRequestExecution clientHttpRequestExecution) throws IOException
        {
            URI uri = httpRequest.getURI();
            // getScheme() is null for relative URIs, hence the constant-first comparison.
            // Schemes are case-insensitive (RFC 3986, section 3.1).
            if ("http".equalsIgnoreCase(uri.getScheme()))
            {
                // NOTE(review): URI.getPort() returns -1 when the URI has no explicit port, and that -1 is
                // passed through verbatim in the target identifier.
                String target = uri.getHost() + ":" + uri.getPort() + ":" + resolveInstanceName(httpRequest);
                TccMethodCallCommand tccMethodCallCommand = rpcFilter.beforeSendRequest(target);
                if (tccMethodCallCommand != null)
                {
                    writeTransactionHeaders(httpRequest, tccMethodCallCommand);
                }
            }
            return clientHttpRequestExecution.execute(httpRequest, bytes);
        }

        /**
         * Returns the host of the request wrapped by {@code httpRequest}, or {@code "unknow"} when the
         * request is not an {@link HttpRequestWrapper}.
         * <p>
         * When an earlier interceptor (presumably the load balancer) wrapped the request, the wrapped
         * request's host is the name the caller originally addressed.
         */
        private static String resolveInstanceName(HttpRequest httpRequest)
        {
            if (httpRequest instanceof HttpRequestWrapper)
            {
                return ((HttpRequestWrapper) httpRequest).getRequest().getURI().getHost();
            }
            return UNKNOWN_INSTANCE;
        }

        /**
         * Copies the call id, node identifier and xid of {@code command} into the request headers.
         * <p>
         * {@code set} replaces any existing values rather than appending them. This keeps each header
         * single-valued even if it was already present, or if this interceptor runs more than once for
         * the same request. Assumes {@code command.getXid()} is non-null, as the original code did.
         */
        private static void writeTransactionHeaders(HttpRequest httpRequest, TccMethodCallCommand command)
        {
            String callId         = command.getCallId();
            String nodeIdentifier = command.getNodeIdentifier();
            Xid    xid            = command.getXid();
            httpRequest.getHeaders().set(CALLID, callId);
            httpRequest.getHeaders().set(NODEIDENTIFIER, nodeIdentifier);
            httpRequest.getHeaders().set(XID, xid.toString());
        }
    }
}
